#!/usr/bin/env python3
"""
Loop-Interference Simulation (Corrected CHSH & Gentle Decoherence)

This version:
- Uses pure CHSH correlator: cos(2*(θ1 − θ2)), without scaling the phase by L.
- Applies a gentle decoherence amplitude: exp(-L / (8*max_L)), so small loops retain near-full violation.
- Injects phase jitter into each correlator call: jitter ~ N(0, jitter_sigma).
- Runs several jittered replicates per grid point and bootstraps them to estimate S_std and visibility_std.
- Writes report in UTF-8.
"""

import os
import math
import random
from itertools import product
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Alice's two measurement settings (radians); fixed for every grid point.
_ALICE_A = math.radians(0.0)
_ALICE_A_PRIME = math.radians(45.0)


def _stable_sigmoid(x: float) -> float:
    """Return the logistic function ``1 / (1 + exp(-x))`` without overflow.

    ``math.exp(-x)`` raises ``OverflowError`` once ``-x`` exceeds ~709, which
    the visibility baseline reaches for large ``n0 / L`` ratios. Branching on
    the sign keeps every exponent non-positive.
    """
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def _chsh_correlator(
    theta1: float,
    theta2: float,
    *,
    decay: float,
    jitter_sigma: float,
    rng: random.Random,
) -> float:
    """Return ``cos(2 * (theta1 - theta2 + jitter)) * decay``.

    ``jitter`` is one draw of ``rng.gauss(0, jitter_sigma)`` per call, so each
    call advances ``rng``.
    """
    delta = theta1 - theta2 + rng.gauss(0, jitter_sigma)
    return math.cos(2.0 * delta) * decay


def _bootstrap_mean_std(
    values: list[float], n_boot: int, rng: np.random.Generator
) -> float:
    """Return the mean population std (``ddof=0``) over bootstrap resamples.

    Draws ``n_boot`` resamples of ``values`` with replacement in one
    vectorized call and averages their standard deviations.

    NOTE: this estimates the spread of individual replicates (close to
    ``np.std(values)``), not the standard error of their mean.
    """
    arr = np.asarray(values, dtype=float)
    resamples = rng.choice(arr, size=(n_boot, arr.size), replace=True)
    return float(resamples.std(axis=1).mean())


def _simulate_point(
    b: float,
    k: float,
    n0: int,
    L: int,
    angle_deg: float,
    *,
    L_scale: float,
    num_replicas: int,
    jitter_sigma: float,
    n_boot: int,
) -> dict[str, float]:
    """Simulate one grid point and return its CSV row.

    Runs ``num_replicas`` jittered CHSH evaluations, then summarizes S and
    visibility by their replicate means and bootstrapped spreads.
    """
    # Logistic visibility baseline, mapped into the range 0.707..1.0.
    x_logit = b * k - 3 * (n0 / L)
    base_vis = 0.707 + 0.293 * _stable_sigmoid(x_logit)

    # Bob's settings (radians) depend on the swept angle.
    bob_b = math.radians(angle_deg / 2.0)
    bob_b_prime = bob_b + math.pi / 4
    settings = (
        (_ALICE_A, bob_b),
        (_ALICE_A, bob_b_prime),
        (_ALICE_A_PRIME, bob_b),
        (_ALICE_A_PRIME, bob_b_prime),
    )

    # Decoherence amplitude depends only on L, so compute it once per point.
    decay = math.exp(-L / L_scale)

    key = (b, k, n0, L, angle_deg)
    S_vals: list[float] = []
    vis_vals: list[float] = []
    for r in range(num_replicas):
        # Per-replicate seed derived from the grid point. A private Random
        # instance avoids reseeding the module-level generator that the
        # caller may depend on.
        rng = random.Random(hash(key + (r,)) & 0xFFFFFFFF)

        # Evaluated in order: E(a,b), E(a,b'), E(a',b), E(a',b').
        E_ab, E_abp, E_apb, E_apbp = (
            _chsh_correlator(t1, t2, decay=decay, jitter_sigma=jitter_sigma, rng=rng)
            for t1, t2 in settings
        )

        # CHSH combination, scaled by the baseline visibility.
        S = (E_ab - E_abp + E_apb + E_apbp) * base_vis

        # Independent multiplicative visibility jitter, clipped to [0, 1].
        vis = base_vis * (1 + rng.gauss(0, 0.02))
        vis = max(0.0, min(1.0, vis))

        S_vals.append(S)
        vis_vals.append(vis)

    # Seed the bootstrap from the grid point so the reported spreads are
    # reproducible from run to run.
    boot_rng = np.random.default_rng(hash(key) & 0xFFFFFFFF)
    S_std = _bootstrap_mean_std(S_vals, n_boot, boot_rng)
    vis_std = _bootstrap_mean_std(vis_vals, n_boot, boot_rng)

    return {
        "b": b, "k": k, "n0": n0, "L": L, "angle": angle_deg,
        "S_mean": sum(S_vals) / num_replicas, "S_std": S_std,
        "visibility_mean": sum(vis_vals) / num_replicas, "visibility_std": vis_std,
    }


def _write_plot(df: pd.DataFrame, path: str) -> None:
    """Save a plot of mean S (averaged over all other parameters) versus angle."""
    fig, ax = plt.subplots()
    try:
        df.groupby("angle")["S_mean"].mean().sort_index().plot(ax=ax, marker="o")
        ax.set_title("Mean CHSH S vs angle with decoherence & jitter")
        ax.set_xlabel("Angle (°)")
        ax.set_ylabel("Mean S")
        ax.grid(True, linestyle="--", alpha=0.7)
        fig.savefig(path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def _write_report(df: pd.DataFrame, path: str, num_replicas: int) -> None:
    """Write the UTF-8 markdown summary: global ranges plus sample rows."""
    lines = [
        "# Loop Interference Simulation Report\n\n",
        f"Total rows: {len(df)}\n",
        f"Replicates per row: {num_replicas}\n\n",
        "## Global Statistics\n",
    ]
    for col in ("S_mean", "visibility_mean", "S_std", "visibility_std"):
        lines.append(f"- {col} ∈ [{df[col].min():.3f}, {df[col].max():.3f}]\n")
    lines.append("\n")

    lines.append("## Sample Results\n")
    # df.sample raises ValueError when asked for more rows than exist.
    sample = df.sample(min(5, len(df)), random_state=1)
    # itertuples preserves column dtypes, so integer columns such as n0 and L
    # print as "10" instead of the float-upcast "10.0" that iterrows yields.
    for r in sample.itertuples(index=False):
        lines.append(
            f"- b={r.b}, k={r.k}, n0={r.n0}, L={r.L}, angle={r.angle}°: "
            f"S_mean={r.S_mean:.3f}±{r.S_std:.3f}, "
            f"vis={r.visibility_mean:.3f}±{r.visibility_std:.3f}\n"
        )

    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def run_interference(
    *,
    b_values: list[float],
    k_values: list[float],
    n0_values: list[int],
    L_values: list[int],
    angles: list[float],
    output_dir: str,
    num_replicas: int = 10,
    jitter_sigma: float = 0.05,  # phase jitter in radians
    n_boot: int = 1000,
) -> None:
    """Sweep the parameter grid, simulate CHSH S and visibility, and write results.

    For every combination of ``(b, k, n0, L, angle)`` this runs
    ``num_replicas`` jittered replicates. It records the mean S and mean
    visibility, and estimates their spread by bootstrapping. Three files are
    written to ``output_dir``, which is created if missing:

    - ``interference_results.csv``: one row per grid point.
    - ``interference_plot.png``: mean S, averaged over the other parameters,
      versus angle.
    - ``interference_report.md``: UTF-8 summary with global ranges and up to
      five sample rows.

    Args:
        b_values, k_values, n0_values: Inputs to the logistic visibility
            baseline ``sigmoid(b * k - 3 * n0 / L)``.
        L_values: Loop lengths; all must be positive. The decoherence
            amplitude is ``exp(-L / (8 * max(L_values)))``.
        angles: Swept angles in degrees. Bob's settings are ``angle / 2``
            and ``angle / 2 + 45`` degrees.
        output_dir: Directory that receives the output files.
        num_replicas: Replicates per grid point; must be >= 1.
        jitter_sigma: Std (radians) of the Gaussian phase jitter added to
            every correlator evaluation.
        n_boot: Bootstrap resamples per grid point; must be >= 1.

    Raises:
        ValueError: If any parameter list is empty, any ``L`` is not positive,
            or ``num_replicas`` / ``n_boot`` is below 1.
    """
    # Materialize once: the grid is iterated by both max() and product().
    b_values, k_values, n0_values = list(b_values), list(k_values), list(n0_values)
    L_values, angles = list(L_values), list(angles)

    named_lists = (
        ("b_values", b_values),
        ("k_values", k_values),
        ("n0_values", n0_values),
        ("L_values", L_values),
        ("angles", angles),
    )
    empty = [name for name, values in named_lists if not values]
    if empty:
        raise ValueError(f"parameter lists must not be empty: {', '.join(empty)}")
    if any(L <= 0 for L in L_values):
        raise ValueError("all L_values must be positive")
    if num_replicas < 1:
        raise ValueError("num_replicas must be >= 1")
    if n_boot < 1:
        raise ValueError("n_boot must be >= 1")

    os.makedirs(output_dir, exist_ok=True)

    # Coherence length 8x the longest loop keeps the decay gentle.
    L_scale = max(L_values) * 8

    rows = [
        _simulate_point(
            b, k, n0, L, angle_deg,
            L_scale=L_scale,
            num_replicas=num_replicas,
            jitter_sigma=jitter_sigma,
            n_boot=n_boot,
        )
        for b, k, n0, L, angle_deg in product(
            b_values, k_values, n0_values, L_values, angles
        )
    ]

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(output_dir, "interference_results.csv"), index=False)
    _write_plot(df, os.path.join(output_dir, "interference_plot.png"))
    _write_report(df, os.path.join(output_dir, "interference_report.md"), num_replicas)

    print(f"Wrote CSV, plot, and report to {output_dir}")